{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img src=\"https://storage.googleapis.com/gweb-uniblog-publish-prod/images/gemma-header.width-1200.format-webp.webp\" width=\"100%\">\n",
    "\n",
    "## Instruct Fine-tuning [Gemma](https://blog.google/technology/developers/gemma-open-models/) using qLora and Supervise Finetuning\n",
    "\n",
    "This is a comprahensive notebook and tutorial on how to fine tune the `gemma-7b-it` Model\n",
    "\n",
    "All the code will be available on my Github. Do drop by and give a follow and a star.\n",
    "[adithya-s-k](https://github.com/adithya-s-k)\n",
    "\\\n",
    "[Github Code](https://github.com/adithya-s-k/LLM-Cookbook/blob/main/LLMs/Gemma/finetune-gemma.ipynb)\n",
    "\n",
    "I also post content about LLMs and what I have been working on Twitter.\n",
    "[AdithyaSK (@adithya_s_k) / X](https://twitter.com/adithya_s_k)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prerequisites\n",
    "\n",
    "Before delving into the fine-tuning process, ensure that you have the following prerequisites in place:\n",
    "\n",
    "1. **GPU**: [gemma-2b](https://huggingface.co/google/gemma-2b) - can be finetuned on T4(free google colab) while [gemma-7b](https://huggingface.co/google/gemma-7b) requires an A100 GPU.\n",
    "2. **Python Packages**: Ensure that you have the necessary Python packages installed. You can use the following commands to install them:\n",
    "\n",
    "Let's begin by checking if your GPU is correctly detected:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Wed Feb 21 17:19:05 2024       \n",
      "+---------------------------------------------------------------------------------------+\n",
      "| NVIDIA-SMI 545.23.08              Driver Version: 545.23.08    CUDA Version: 12.3     |\n",
      "|-----------------------------------------+----------------------+----------------------+\n",
      "| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |\n",
      "| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |\n",
      "|                                         |                      |               MIG M. |\n",
      "|=========================================+======================+======================|\n",
      "|   0  NVIDIA A100 80GB PCIe          On  | 00000001:00:00.0 Off |                    0 |\n",
      "| N/A   33C    P0              42W / 300W |      4MiB / 81920MiB |      0%      Default |\n",
      "|                                         |                      |             Disabled |\n",
      "+-----------------------------------------+----------------------+----------------------+\n",
      "                                                                                         \n",
      "+---------------------------------------------------------------------------------------+\n",
      "| Processes:                                                                            |\n",
      "|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |\n",
      "|        ID   ID                                                             Usage      |\n",
      "|=======================================================================================|\n",
      "|  No running processes found                                                           |\n",
      "+---------------------------------------------------------------------------------------+\n"
     ]
    }
   ],
   "source": [
    "!nvidia-smi"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 2 - Model loading\n",
    "We'll load the model using QLoRA quantization to reduce the usage of memory\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip3 install -q -U bitsandbytes==0.42.0\n",
    "!pip3 install -q -U peft==0.8.2\n",
    "!pip3 install -q -U trl==0.7.10\n",
    "!pip3 install -q -U accelerate==0.27.1\n",
    "!pip3 install -q -U datasets==2.17.0\n",
    "!pip3 install -q -U transformers==4.38.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig\n",
    "\n",
    "bnb_config = BitsAndBytesConfig(\n",
    "    load_in_4bit=True,\n",
    "    bnb_4bit_use_double_quant=True,\n",
    "    bnb_4bit_quant_type=\"nf4\",\n",
    "    bnb_4bit_compute_dtype=torch.bfloat16\n",
    ")"
   ]
  },
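  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A rough back-of-the-envelope check of why 4-bit loading helps. This is a sketch, not a measurement: it assumes the layer shapes that `print(model)` reports later in this notebook (28 decoder layers, hidden size 3072, attention width 4096, MLP width 24576), counts only the `Linear4bit` weights, and ignores quantization constants, activations, and optimizer state. The variable names are made up for illustration.\n",
    "\n",
    "```python\n",
    "# Shapes taken from the print(model) output further down in this notebook.\n",
    "hidden, attn, mlp, layers = 3072, 4096, 24576, 28\n",
    "\n",
    "# q_proj, k_proj, v_proj and o_proj each hold hidden * attn weights.\n",
    "attn_params = 4 * hidden * attn\n",
    "# gate_proj, up_proj and down_proj each hold hidden * mlp weights.\n",
    "mlp_params = 3 * hidden * mlp\n",
    "linear_params = layers * (attn_params + mlp_params)\n",
    "\n",
    "GIB = 1024 ** 3\n",
    "bf16_gib = linear_params * 2 / GIB    # 16 bits = 2 bytes per weight\n",
    "nf4_gib = linear_params * 0.5 / GIB   # 4 bits = 0.5 bytes per weight\n",
    "\n",
    "print(f'linear weights: {linear_params:,}')\n",
    "print(f'16-bit: {bf16_gib:.2f} GiB | 4-bit: {nf4_gib:.2f} GiB')\n",
    "```\n",
    "\n",
    "The embedding table and the norm layers are not quantized, so the real footprint is somewhat higher than the 4-bit figure."
   ]
  },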
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we specify the model ID and then we load it with our previously defined quantization configuration.Now we specify the model ID and then we load it with our previously defined quantization configuration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# if you are using google colab\n",
    "\n",
    "# import os\n",
    "# from google.colab import userdata\n",
    "# os.environ[\"HF_TOKEN\"] = userdata.get('HF_TOKEN')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from huggingface_hub import notebook_login\n",
    "notebook_login()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_id = \"google/gemma-7b-it\"\n",
    "# model_id = \"google/gemma-7b\"\n",
    "# model_id = \"google/gemma-2b-it\"\n",
    "# model_id = \"google/gemma-2b\"\n",
    "\n",
    "model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=bnb_config, device_map={\"\":0})\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_id, add_eos_token=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_completion(query: str, model, tokenizer) -> str:\n",
    "  device = \"cuda:0\"\n",
    "\n",
    "  prompt_template = \"\"\"\n",
    "  <start_of_turn>user\n",
    "  Below is an instruction that describes a task. Write a response that appropriately completes the request.\n",
    "  {query}\n",
    "  <end_of_turn>\\n<start_of_turn>model\n",
    "  \n",
    "\n",
    "  \"\"\"\n",
    "  prompt = prompt_template.format(query=query)\n",
    "\n",
    "  encodeds = tokenizer(prompt, return_tensors=\"pt\", add_special_tokens=True)\n",
    "\n",
    "  model_inputs = encodeds.to(device)\n",
    "\n",
    "\n",
    "  generated_ids = model.generate(**model_inputs, max_new_tokens=1000, do_sample=True, pad_token_id=tokenizer.eos_token_id)\n",
    "  # decoded = tokenizer.batch_decode(generated_ids)\n",
    "  decoded = tokenizer.decode(generated_ids[0], skip_special_tokens=True)\n",
    "  return (decoded)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "  user\n",
      "  Below is an instruction that describes a task. Write a response that appropriately completes the request.\n",
      "  code the fibonacci series in python using reccursion\n",
      "  \n",
      "model\n",
      "  \n",
      "\n",
      "   a Python function to calculate the nth Fibonacci number using recursion. Here's the code: \n",
      "\n",
      "```python\n",
      "\n",
      "def fibonacci(n):\n",
      "  if n == 0:\n",
      "    return 0\n",
      "  elif n == 1:\n",
      "    return 1\n",
      "  else:\n",
      "    return fibonacci(n-1) + fibonacci(n-2)\n",
      "\n",
      "# Print the nth Fibonacci number\n",
      "print(fibonacci(10))\n",
      "```\n",
      "\n",
      "**Explanation:**\n",
      "\n",
      "1. The function `fibonacci` takes an integer `n` as input.\n",
      "2. If `n` is 0 or 1, it returns the respective base case of 0 or 1.\n",
      "3. Otherwise, it recursively calculates the Fibonacci number for `n-1` and `n-2` and adds their sum to return the Fibonacci number for `n`.\n",
      "4. The function continues to recurse until `n` reaches the desired number, and the final result is returned.\n",
      "\n",
      "**Output:**\n",
      "\n",
      "```\n",
      ">>> print(fibonacci(10))\n",
      "5\n",
      "```\n",
      "\n",
      "In this example, the code calculates the 10th Fibonacci number, which is 5.\n",
      "\n",
      "**Note:** Reccursion can be a powerful technique for solving problems that involve repeated calculations. However, it is important to note that recursion can also lead to stack overflow errors for large values of `n` due to its repeated function calls. For more efficient solutions, iterative approaches are often preferred.\n"
     ]
    }
   ],
   "source": [
    "result = get_completion(query=\"code the fibonacci series in python using reccursion\", model=model, tokenizer=tokenizer)\n",
    "print(result)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 3 - Load dataset for finetuning"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Lets Load the Dataset\n",
    "\n",
    "For this tutorial, we will fine-tune Mistral 7B Instruct for code generation.\n",
    "\n",
    "We will be using this [dataset](https://huggingface.co/datasets/TokenBender/code_instructions_122k_alpaca_style) which is curated by [TokenBender (e/xperiments)](https://twitter.com/4evaBehindSOTA) and is an excellent data source for fine-tuning models for code generation. It follows the alpaca style of instructions, which is an excellent starting point for this task. The dataset structure should resemble the following:\n",
    "\n",
    "```json\n",
    "{\n",
    "  \"instruction\": \"Create a function to calculate the sum of a sequence of integers.\",\n",
    "  \"input\": \"[1, 2, 3, 4, 5]\",\n",
    "  \"output\": \"# Python code def sum_sequence(sequence): sum = 0 for num in sequence: sum += num return sum\"\n",
    "}\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "\n",
    "dataset = load_dataset(\"TokenBender/code_instructions_122k_alpaca_style\", split=\"train\")\n",
    "dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>input</th>\n",
       "      <th>output</th>\n",
       "      <th>text</th>\n",
       "      <th>instruction</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>[1, 2, 3, 4, 5]</td>\n",
       "      <td># Python code\\ndef sum_sequence(sequence):\\n  ...</td>\n",
       "      <td>Below is an instruction that describes a task....</td>\n",
       "      <td>Create a function to calculate the sum of a se...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>str1 = \"Hello \"\\nstr2 = \"world\"</td>\n",
       "      <td>def add_strings(str1, str2):\\n    \"\"\"This func...</td>\n",
       "      <td>Below is an instruction that describes a task....</td>\n",
       "      <td>Develop a function that will add two strings</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td></td>\n",
       "      <td>#include &lt;map&gt;\\n#include &lt;string&gt;\\n\\nclass Gro...</td>\n",
       "      <td>Below is an instruction that describes a task....</td>\n",
       "      <td>Design a data structure in C++ to store inform...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>[3, 1, 4, 5, 9, 0]</td>\n",
       "      <td>def bubble_sort(arr):\\n    n = len(arr)\\n \\n  ...</td>\n",
       "      <td>Below is an instruction that describes a task....</td>\n",
       "      <td>Implement a sorting algorithm to sort a given ...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>Not applicable</td>\n",
       "      <td>import UIKit\\n\\nclass ExpenseViewController: U...</td>\n",
       "      <td>Below is an instruction that describes a task....</td>\n",
       "      <td>Design a Swift application for tracking expens...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>Not Applicable</td>\n",
       "      <td>&lt;?php\\n$timestamp = $_GET['timestamp'];\\n\\nif(...</td>\n",
       "      <td>Below is an instruction that describes a task....</td>\n",
       "      <td>Create a REST API to convert a UNIX timestamp ...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>website: www.example.com \\ndata to crawl: phon...</td>\n",
       "      <td>import requests\\nimport re\\n\\ndef crawl_websit...</td>\n",
       "      <td>Below is an instruction that describes a task....</td>\n",
       "      <td>Generate a Python code for crawling a website ...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td></td>\n",
       "      <td>[x*x for x in [1, 2, 3, 5, 8, 13]]</td>\n",
       "      <td>Below is an instruction that describes a task....</td>\n",
       "      <td>Create a Python list comprehension to get the ...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td></td>\n",
       "      <td>SELECT * FROM products ORDER BY price DESC LIM...</td>\n",
       "      <td>Below is an instruction that describes a task....</td>\n",
       "      <td>Create a MySQL query to find the most expensiv...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>Not applicable</td>\n",
       "      <td>public class Library {\\n \\n // map of books in...</td>\n",
       "      <td>Below is an instruction that describes a task....</td>\n",
       "      <td>Create a data structure in Java for storing an...</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                               input  \\\n",
       "0                                    [1, 2, 3, 4, 5]   \n",
       "1                    str1 = \"Hello \"\\nstr2 = \"world\"   \n",
       "2                                                      \n",
       "3                                 [3, 1, 4, 5, 9, 0]   \n",
       "4                                     Not applicable   \n",
       "5                                     Not Applicable   \n",
       "6  website: www.example.com \\ndata to crawl: phon...   \n",
       "7                                                      \n",
       "8                                                      \n",
       "9                                     Not applicable   \n",
       "\n",
       "                                              output  \\\n",
       "0  # Python code\\ndef sum_sequence(sequence):\\n  ...   \n",
       "1  def add_strings(str1, str2):\\n    \"\"\"This func...   \n",
       "2  #include <map>\\n#include <string>\\n\\nclass Gro...   \n",
       "3  def bubble_sort(arr):\\n    n = len(arr)\\n \\n  ...   \n",
       "4  import UIKit\\n\\nclass ExpenseViewController: U...   \n",
       "5  <?php\\n$timestamp = $_GET['timestamp'];\\n\\nif(...   \n",
       "6  import requests\\nimport re\\n\\ndef crawl_websit...   \n",
       "7                 [x*x for x in [1, 2, 3, 5, 8, 13]]   \n",
       "8  SELECT * FROM products ORDER BY price DESC LIM...   \n",
       "9  public class Library {\\n \\n // map of books in...   \n",
       "\n",
       "                                                text  \\\n",
       "0  Below is an instruction that describes a task....   \n",
       "1  Below is an instruction that describes a task....   \n",
       "2  Below is an instruction that describes a task....   \n",
       "3  Below is an instruction that describes a task....   \n",
       "4  Below is an instruction that describes a task....   \n",
       "5  Below is an instruction that describes a task....   \n",
       "6  Below is an instruction that describes a task....   \n",
       "7  Below is an instruction that describes a task....   \n",
       "8  Below is an instruction that describes a task....   \n",
       "9  Below is an instruction that describes a task....   \n",
       "\n",
       "                                         instruction  \n",
       "0  Create a function to calculate the sum of a se...  \n",
       "1       Develop a function that will add two strings  \n",
       "2  Design a data structure in C++ to store inform...  \n",
       "3  Implement a sorting algorithm to sort a given ...  \n",
       "4  Design a Swift application for tracking expens...  \n",
       "5  Create a REST API to convert a UNIX timestamp ...  \n",
       "6  Generate a Python code for crawling a website ...  \n",
       "7  Create a Python list comprehension to get the ...  \n",
       "8  Create a MySQL query to find the most expensiv...  \n",
       "9  Create a data structure in Java for storing an...  "
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df = dataset.to_pandas()\n",
    "df.head(10)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Instruction Fintuning - Prepare the dataset under the format of \"prompt\" so the model can better understand :\n",
    "1. the function generate_prompt : take the instruction and output and generate a prompt\n",
    "2. shuffle the dataset\n",
    "3. tokenizer the dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Formatting the Dataset\n",
    "\n",
    "Now, let's format the dataset in the required [gemma instruction formate](https://huggingface.co/google/gemma-7b-it).\n",
    "\n",
    "> Many tutorials and blogs skip over this part, but I feel this is a really important step.\n",
    "\n",
    "```\n",
    "<start_of_turn>user What is your favorite condiment? <end_of_turn>\n",
    "<start_of_turn>model Well, I'm quite partial to a good squeeze of fresh lemon juice. It adds just the right amount of zesty flavor to whatever I'm cooking up in the kitchen!<end_of_turn>\n",
    "```\n",
    "\n",
    "You can use the following code to process your dataset and create a JSONL file in the correct format:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_prompt(data_point):\n",
    "    \"\"\"Gen. input text based on a prompt, task instruction, (context info.), and answer\n",
    "\n",
    "    :param data_point: dict: Data point\n",
    "    :return: dict: tokenzed prompt\n",
    "    \"\"\"\n",
    "    prefix_text = 'Below is an instruction that describes a task. Write a response that ' \\\n",
    "               'appropriately completes the request.\\n\\n'\n",
    "    # Samples with additional context into.\n",
    "    if data_point['input']:\n",
    "        text = f\"\"\"<start_of_turn>user {prefix_text} {data_point[\"instruction\"]} here are the inputs {data_point[\"input\"]} <end_of_turn>\\n<start_of_turn>model{data_point[\"output\"]} <end_of_turn>\"\"\"\n",
    "    # Without\n",
    "    else:\n",
    "        text = f\"\"\"<start_of_turn>user {prefix_text} {data_point[\"instruction\"]} <end_of_turn>\\n<start_of_turn>model{data_point[\"output\"]} <end_of_turn>\"\"\"\n",
    "    return text\n",
    "\n",
    "# add the \"prompt\" column in the dataset\n",
    "text_column = [generate_prompt(data_point) for data_point in dataset]\n",
    "dataset = dataset.add_column(\"prompt\", text_column)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We'll need to tokenize our data so the model can understand.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = dataset.shuffle(seed=1234)  # Shuffle dataset here\n",
    "dataset = dataset.map(lambda samples: tokenizer(samples[\"prompt\"]), batched=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Split dataset into 90% for training and 10% for testing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = dataset.train_test_split(test_size=0.2)\n",
    "train_data = dataset[\"train\"]\n",
    "test_data = dataset[\"test\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### After Formatting, We should get something like this\n",
    "\n",
    "```json\n",
    "{\n",
    "\"text\":\"<start_of_turn>user Create a function to calculate the sum of a sequence of integers. here are the inputs [1, 2, 3, 4, 5] <end_of_turn>\n",
    "<start_of_turn>model # Python code def sum_sequence(sequence): sum = 0 for num in sequence: sum += num return sum <end_of_turn>\",\n",
    "\"instruction\":\"Create a function to calculate the sum of a sequence of integers\",\n",
    "\"input\":\"[1, 2, 3, 4, 5]\",\n",
    "\"output\":\"# Python code def sum_sequence(sequence): sum = 0 for num in,\n",
    " sequence: sum += num return sum\",\n",
    "\"prompt\":\"<start_of_turn>user Create a function to calculate the sum of a sequence of integers. here are the inputs [1, 2, 3, 4, 5] <end_of_turn>\n",
    "<start_of_turn>model # Python code def sum_sequence(sequence): sum = 0 for num in sequence: sum += num return sum <end_of_turn>\"\n",
    "\n",
    "}\n",
    "```\n",
    "\n",
    "While using SFT (**[Supervised Fine-tuning Trainer](https://huggingface.co/docs/trl/main/en/sft_trainer)**) for fine-tuning, we will be only passing in the “text” column of the dataset for fine-tuning."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Dataset({\n",
      "    features: ['input', 'output', 'text', 'instruction', 'prompt', 'input_ids', 'attention_mask'],\n",
      "    num_rows: 24392\n",
      "})\n"
     ]
    }
   ],
   "source": [
    "print(test_data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 4 - Apply Lora  \n",
    "Here comes the magic with peft! Let's load a PeftModel and specify that we are going to use low-rank adapters (LoRA) using get_peft_model utility function and  the prepare_model_for_kbit_training method from PEFT."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "from peft import LoraConfig, PeftModel, prepare_model_for_kbit_training, get_peft_model\n",
    "model.gradient_checkpointing_enable()\n",
    "model = prepare_model_for_kbit_training(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "GemmaForCausalLM(\n",
      "  (model): GemmaModel(\n",
      "    (embed_tokens): Embedding(256000, 3072, padding_idx=0)\n",
      "    (layers): ModuleList(\n",
      "      (0-27): 28 x GemmaDecoderLayer(\n",
      "        (self_attn): GemmaSdpaAttention(\n",
      "          (q_proj): Linear4bit(in_features=3072, out_features=4096, bias=False)\n",
      "          (k_proj): Linear4bit(in_features=3072, out_features=4096, bias=False)\n",
      "          (v_proj): Linear4bit(in_features=3072, out_features=4096, bias=False)\n",
      "          (o_proj): Linear4bit(in_features=4096, out_features=3072, bias=False)\n",
      "          (rotary_emb): GemmaRotaryEmbedding()\n",
      "        )\n",
      "        (mlp): GemmaMLP(\n",
      "          (gate_proj): Linear4bit(in_features=3072, out_features=24576, bias=False)\n",
      "          (up_proj): Linear4bit(in_features=3072, out_features=24576, bias=False)\n",
      "          (down_proj): Linear4bit(in_features=24576, out_features=3072, bias=False)\n",
      "          (act_fn): GELUActivation()\n",
      "        )\n",
      "        (input_layernorm): GemmaRMSNorm()\n",
      "        (post_attention_layernorm): GemmaRMSNorm()\n",
      "      )\n",
      "    )\n",
      "    (norm): GemmaRMSNorm()\n",
      "  )\n",
      "  (lm_head): Linear(in_features=3072, out_features=256000, bias=False)\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "print(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "import bitsandbytes as bnb\n",
    "def find_all_linear_names(model):\n",
    "  cls = bnb.nn.Linear4bit #if args.bits == 4 else (bnb.nn.Linear8bitLt if args.bits == 8 else torch.nn.Linear)\n",
    "  lora_module_names = set()\n",
    "  for name, module in model.named_modules():\n",
    "    if isinstance(module, cls):\n",
    "      names = name.split('.')\n",
    "      lora_module_names.add(names[0] if len(names) == 1 else names[-1])\n",
    "    if 'lm_head' in lora_module_names: # needed for 16-bit\n",
    "      lora_module_names.remove('lm_head')\n",
    "  return list(lora_module_names)"
   ]
  },
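  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see what `find_all_linear_names` does to each module path, here is a library-free sketch of the same name extraction. The dotted paths are hypothetical examples patterned on the `print(model)` output above, and `last_component_names` is an illustrative helper, not part of PEFT or bitsandbytes.\n",
    "\n",
    "```python\n",
    "def last_component_names(module_paths):\n",
    "    # Keep only the final component of each dotted path, de-duplicated and sorted.\n",
    "    return sorted({path.split('.')[-1] for path in module_paths})\n",
    "\n",
    "# Hypothetical paths in the style reported by model.named_modules().\n",
    "paths = [\n",
    "    'model.layers.0.self_attn.q_proj',\n",
    "    'model.layers.0.self_attn.k_proj',\n",
    "    'model.layers.1.self_attn.q_proj',\n",
    "    'model.layers.1.mlp.down_proj',\n",
    "]\n",
    "print(last_component_names(paths))  # ['down_proj', 'k_proj', 'q_proj']\n",
    "```\n",
    "\n",
    "Because the names are collected in a set, each projection name appears once no matter how many decoder layers contain it."
   ]
  },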
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['o_proj', 'q_proj', 'up_proj', 'v_proj', 'k_proj', 'down_proj', 'gate_proj']\n"
     ]
    }
   ],
   "source": [
    "modules = find_all_linear_names(model)\n",
    "print(modules)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "from peft import LoraConfig, get_peft_model\n",
    "\n",
    "lora_config = LoraConfig(\n",
    "    r=64,\n",
    "    lora_alpha=32,\n",
    "    target_modules=modules,\n",
    "    lora_dropout=0.05,\n",
    "    bias=\"none\",\n",
    "    task_type=\"CAUSAL_LM\"\n",
    ")\n",
    "\n",
    "model = get_peft_model(model, lora_config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Trainable: 200015872 | total: 8737696768 | Percentage: 2.2891%\n"
     ]
    }
   ],
   "source": [
    "trainable, total = model.get_nb_trainable_parameters()\n",
    "print(f\"Trainable: {trainable} | total: {total} | Percentage: {trainable/total*100:.4f}%\")"
   ]
  },
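  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The trainable-parameter count can be reproduced by hand. A LoRA adapter on a linear layer with `in_features` inputs and `out_features` outputs adds two matrices of rank `r`, that is `r * (in_features + out_features)` weights. The sketch below applies that formula to the shapes from `print(model)` with `r=64` and no bias terms; the variable names are illustrative.\n",
    "\n",
    "```python\n",
    "r, layers = 64, 28\n",
    "\n",
    "# (in_features, out_features) for each targeted projection in one decoder layer.\n",
    "shapes = {\n",
    "    'q_proj': (3072, 4096),\n",
    "    'k_proj': (3072, 4096),\n",
    "    'v_proj': (3072, 4096),\n",
    "    'o_proj': (4096, 3072),\n",
    "    'gate_proj': (3072, 24576),\n",
    "    'up_proj': (3072, 24576),\n",
    "    'down_proj': (24576, 3072),\n",
    "}\n",
    "\n",
    "# Each adapter holds A (r x in_features) and B (out_features x r).\n",
    "per_layer = sum(r * (n_in + n_out) for n_in, n_out in shapes.values())\n",
    "trainable = layers * per_layer\n",
    "\n",
    "print(trainable)  # 200015872, matching the count printed above\n",
    "```\n",
    "\n",
    "Doubling `r` doubles this count, which is the main lever for trading adapter capacity against memory."
   ]
  },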
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 5 - Run the training!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Setting the training arguments:\n",
    "* for the reason of demo, we just ran it for few steps (100) just to showcase how to use this integration with existing tools on the HF ecosystem."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# import transformers\n",
    "\n",
    "# tokenizer.pad_token = tokenizer.eos_token\n",
    "\n",
    "\n",
    "# trainer = transformers.Trainer(\n",
    "#     model=model,\n",
    "#     train_dataset=train_data,\n",
    "#     eval_dataset=test_data,\n",
    "#     args=transformers.TrainingArguments(\n",
    "#         per_device_train_batch_size=1,\n",
    "#         gradient_accumulation_steps=4,\n",
    "#         warmup_steps=0.03,\n",
    "#         max_steps=100,\n",
    "#         learning_rate=2e-4,\n",
    "#         fp16=True,\n",
    "#         logging_steps=1,\n",
    "#         output_dir=\"outputs_mistral_b_finance_finetuned_test\",\n",
    "#         optim=\"paged_adamw_8bit\",\n",
    "#         save_strategy=\"epoch\",\n",
    "#     ),\n",
    "#     data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),\n",
    "# )\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Fine-Tuning with qLora and Supervised Fine-Tuning\n",
    "\n",
    "We're ready to fine-tune our model using qLora. For this tutorial, we'll use the `SFTTrainer` from the `trl` library for supervised fine-tuning. Ensure that you've installed the `trl` library as mentioned in the prerequisites."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#new code using SFTTrainer\n",
    "import transformers\n",
    "\n",
    "from trl import SFTTrainer\n",
    "\n",
    "tokenizer.pad_token = tokenizer.eos_token\n",
    "torch.cuda.empty_cache()\n",
    "\n",
    "trainer = SFTTrainer(\n",
    "    model=model,\n",
    "    train_dataset=train_data,\n",
    "    eval_dataset=test_data,\n",
    "    dataset_text_field=\"prompt\",\n",
    "    peft_config=lora_config,\n",
    "    args=transformers.TrainingArguments(\n",
    "        per_device_train_batch_size=1,\n",
    "        gradient_accumulation_steps=4,\n",
    "        warmup_steps=0.03,\n",
    "        max_steps=100,\n",
    "        learning_rate=2e-4,\n",
    "        logging_steps=1,\n",
    "        output_dir=\"outputs\",\n",
    "        optim=\"paged_adamw_8bit\",\n",
    "        save_strategy=\"epoch\",\n",
    "    ),\n",
    "    data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),\n",
    ")"
   ]
  },
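  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "It helps to work out how much data these arguments actually cover. A minimal sketch, assuming a single GPU and taking the training-split size as roughly four times the 24,392 test rows printed earlier (an approximation derived from `test_size=0.2`, not an exact count):\n",
    "\n",
    "```python\n",
    "per_device_train_batch_size = 1\n",
    "gradient_accumulation_steps = 4\n",
    "max_steps = 100\n",
    "num_gpus = 1  # assumption: single-GPU run\n",
    "\n",
    "# Sequences contributing to each optimizer update.\n",
    "effective_batch = per_device_train_batch_size * gradient_accumulation_steps * num_gpus\n",
    "# Sequences consumed over the whole run.\n",
    "sequences_seen = effective_batch * max_steps\n",
    "\n",
    "approx_train_rows = 4 * 24392  # assumption: train split is about 4x the test split\n",
    "fraction = sequences_seen / approx_train_rows\n",
    "\n",
    "print(effective_batch, sequences_seen, f'{fraction:.2%}')\n",
    "```\n",
    "\n",
    "So the 100-step demo touches well under 1% of the training split, which is consistent with the progress bar reporting `Epoch 0/1`."
   ]
  },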
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Lets start training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/adithya/miniconda3/envs/gemma-venv/lib/python3.10/site-packages/torch/utils/checkpoint.py:460: UserWarning: torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly. The default value of use_reentrant will be updated to be False in the future. To maintain current behavior, pass use_reentrant=True. It is recommended that you use use_reentrant=False. Refer to docs for more details on the differences between the two variants.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='100' max='100' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [100/100 03:06, Epoch 0/1]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Step</th>\n",
       "      <th>Training Loss</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>10.299600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>6.640600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>7.763200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4</td>\n",
       "      <td>4.142200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5</td>\n",
       "      <td>3.619800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6</td>\n",
       "      <td>4.840600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>7</td>\n",
       "      <td>2.886800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>8</td>\n",
       "      <td>1.334500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>9</td>\n",
       "      <td>1.177300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>10</td>\n",
       "      <td>1.343300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>11</td>\n",
       "      <td>1.306000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>12</td>\n",
       "      <td>1.122000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>13</td>\n",
       "      <td>1.337300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>14</td>\n",
       "      <td>1.083400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>15</td>\n",
       "      <td>1.399100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>16</td>\n",
       "      <td>1.172800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>17</td>\n",
       "      <td>0.838900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>18</td>\n",
       "      <td>1.027200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>19</td>\n",
       "      <td>0.847500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>20</td>\n",
       "      <td>1.065700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>21</td>\n",
       "      <td>0.944100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>22</td>\n",
       "      <td>0.782900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>23</td>\n",
       "      <td>1.030100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>24</td>\n",
       "      <td>0.911900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>25</td>\n",
       "      <td>0.868400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>26</td>\n",
       "      <td>0.768300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>27</td>\n",
       "      <td>0.845200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>28</td>\n",
       "      <td>1.036400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>29</td>\n",
       "      <td>0.766000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>30</td>\n",
       "      <td>0.741500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>31</td>\n",
       "      <td>0.738800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>32</td>\n",
       "      <td>0.779200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>33</td>\n",
       "      <td>0.833000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>34</td>\n",
       "      <td>0.851800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>35</td>\n",
       "      <td>0.824800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>36</td>\n",
       "      <td>0.757800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>37</td>\n",
       "      <td>0.804500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>38</td>\n",
       "      <td>0.962900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>39</td>\n",
       "      <td>0.716000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>40</td>\n",
       "      <td>1.027000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>41</td>\n",
       "      <td>0.743100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>42</td>\n",
       "      <td>0.932800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>43</td>\n",
       "      <td>0.623100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>44</td>\n",
       "      <td>0.671600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>45</td>\n",
       "      <td>0.807100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>46</td>\n",
       "      <td>0.957200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>47</td>\n",
       "      <td>0.643900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>48</td>\n",
       "      <td>0.867800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>49</td>\n",
       "      <td>0.810700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>50</td>\n",
       "      <td>0.871100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>51</td>\n",
       "      <td>0.628500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>52</td>\n",
       "      <td>0.946400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>53</td>\n",
       "      <td>0.918600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>54</td>\n",
       "      <td>0.920300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>55</td>\n",
       "      <td>0.746100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>56</td>\n",
       "      <td>0.914800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>57</td>\n",
       "      <td>0.705700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>58</td>\n",
       "      <td>0.883300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>59</td>\n",
       "      <td>1.016300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>60</td>\n",
       "      <td>0.583200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>61</td>\n",
       "      <td>0.872000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>62</td>\n",
       "      <td>0.617600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>63</td>\n",
       "      <td>0.858700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>64</td>\n",
       "      <td>0.955700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>65</td>\n",
       "      <td>0.854500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>66</td>\n",
       "      <td>0.778000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>67</td>\n",
       "      <td>0.733200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>68</td>\n",
       "      <td>0.871200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>69</td>\n",
       "      <td>0.847700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>70</td>\n",
       "      <td>0.567400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>71</td>\n",
       "      <td>1.078200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>72</td>\n",
       "      <td>0.945800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>73</td>\n",
       "      <td>0.762500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>74</td>\n",
       "      <td>0.618800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>75</td>\n",
       "      <td>0.803600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>76</td>\n",
       "      <td>0.848400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>77</td>\n",
       "      <td>0.504300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>78</td>\n",
       "      <td>0.685300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>79</td>\n",
       "      <td>0.700700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>80</td>\n",
       "      <td>0.636400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>81</td>\n",
       "      <td>0.832700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>82</td>\n",
       "      <td>0.614600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>83</td>\n",
       "      <td>0.899000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>84</td>\n",
       "      <td>0.623800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>85</td>\n",
       "      <td>0.637200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>86</td>\n",
       "      <td>0.697500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>87</td>\n",
       "      <td>0.770200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>88</td>\n",
       "      <td>0.800000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>89</td>\n",
       "      <td>0.748600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>90</td>\n",
       "      <td>0.897100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>91</td>\n",
       "      <td>0.718800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>92</td>\n",
       "      <td>0.703900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>93</td>\n",
       "      <td>0.635600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>94</td>\n",
       "      <td>0.649800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>95</td>\n",
       "      <td>0.627200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>96</td>\n",
       "      <td>0.785200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>97</td>\n",
       "      <td>0.808200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>98</td>\n",
       "      <td>0.720500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>99</td>\n",
       "      <td>0.656600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>100</td>\n",
       "      <td>0.650900</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "TrainOutput(global_step=100, training_loss=1.1844707012176514, metrics={'train_runtime': 188.6612, 'train_samples_per_second': 2.12, 'train_steps_per_second': 0.53, 'total_flos': 3947993787666432.0, 'train_loss': 1.1844707012176514, 'epoch': 0.0})"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.config.use_cache = False  # silence the warnings. Please re-enable for inference!\n",
    "trainer.train()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Share adapters on the 🤗 Hub"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "new_model = \"gemma-Code-Instruct-Finetune-test\"  # Name of the model to push to the Hugging Face Model Hub"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.model.save_pretrained(new_model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from peft import PeftModel\n",
    "\n",
    "# Reload the base model in float16, without a quantization config, so the LoRA weights can be merged into it\n",
    "base_model = AutoModelForCausalLM.from_pretrained(\n",
    "    model_id,\n",
    "    low_cpu_mem_usage=True,\n",
    "    return_dict=True,\n",
    "    torch_dtype=torch.float16,\n",
    "    device_map={\"\": 0},  # place the whole model on GPU 0\n",
    ")\n",
    "# Attach the saved adapter, then fold its weights into the base model\n",
    "merged_model = PeftModel.from_pretrained(base_model, new_model)\n",
    "merged_model = merged_model.merge_and_unload()\n",
    "\n",
    "# Save the merged model and the tokenizer\n",
    "merged_model.save_pretrained(\"merged_model\", safe_serialization=True)\n",
    "tokenizer.save_pretrained(\"merged_model\")\n",
    "tokenizer.pad_token = tokenizer.eos_token\n",
    "tokenizer.padding_side = \"right\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Push the model and tokenizer to the Hugging Face Model Hub\n",
    "merged_model.push_to_hub(new_model, use_temp_dir=False)\n",
    "tokenizer.push_to_hub(new_model, use_temp_dir=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Test out the fine-tuned model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "  user\n",
      "  Below is an instruction that describes a task. Write a response that appropriately completes the request.\n",
      "  code the fibonacci series in python using reccursion\n",
      "  \n",
      "model\n",
      "  \n",
      "\n",
      "   a Python program to code the Fibonacci series using recursion. Here's a solution:\n",
      "\n",
      "```python\n",
      "def fibonacci(n):\n",
      "  \"\"\"Calculates the nth Fibonacci number using recursion.\n",
      "\n",
      "  Args:\n",
      "    n: The index of the Fibonacci number to calculate.\n",
      "\n",
      "  Returns:\n",
      "    The nth Fibonacci number.\n",
      "  \"\"\"\n",
      "\n",
      "  # Base case: The first two Fibonacci numbers are 0 and 1.\n",
      "  if n <= 1:\n",
      "    return n\n",
      "\n",
      "  # Recursive case: Otherwise, calculate the nth Fibonacci number by adding the previous two numbers.\n",
      "  else:\n",
      "    return fibonacci(n-1) + fibonacci(n-2)\n",
      "\n",
      "\n",
      "# Print the Fibonacci numbers.\n",
      "for i in range(10):\n",
      "  print(fibonacci(i))\n",
      "```\n",
      "\n",
      "**Explanation:**\n",
      "\n",
      "* The `fibonacci` function takes an integer `n` as input.\n",
      "* If `n` is less than or equal to 1, it returns `n` itself, as the base case.\n",
      "* Otherwise, it calculates `n`-th Fibonacci number recursively by adding the previous two numbers.\n",
      "* The function calls itself with smaller and smaller values of `n` until it reaches the base case.\n",
      "\n",
      "**Example Usage:**\n",
      "\n",
      "```python\n",
      "# Print the Fibonacci numbers from 0 to 10.\n",
      "for i in range(10):\n",
      "  print(fibonacci(i))\n",
      "```\n",
      "\n",
      "**Output:**\n",
      "\n",
      "```\n",
      "0\n",
      "1\n",
      "1\n",
      "2\n",
      "3\n",
      "5\n",
      "8\n",
      "13\n",
      "21\n",
      "34\n",
      "```\n",
      "\n",
      "**Note:**\n",
      "\n",
      "* This program calculates Fibonacci numbers recursively, which means that it may not be very efficient for large values of `n` as it repeats calculations.\n",
      "* For more efficient Fibonacci number calculation, consider using an iterative approach rather than recursion.\n"
     ]
    }
   ],
   "source": [
    "result = get_completion(query=\"code the fibonacci series in python using reccursion\", model=merged_model, tokenizer=tokenizer)\n",
    "print(result)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "gemma-venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
